/*
 * Copyright (C) 2015 The Pennsylvania State University and the University of Wisconsin
 * Systems and Internet Infrastructure Security Laboratory
 *
 * Author: Damien Octeau
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.psu.cse.siis.ic3.db;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Base class for tables whose rows are identified by a pair of integer columns. It provides
 * single-row and batched find/insert operations that resolve each (first, second) pair to the
 * row's id.
 *
 * <p>Statements prepared for batch operations are closed before each batch method returns. The
 * cached single-row statements ({@code insertStatement}, {@code findStatement}) are reused across
 * calls. They are only re-prepared when they are absent or closed.
 */
public abstract class TwoIntTable extends Table {
  private static final String INSERT = "INSERT INTO \"%s\" (%s, %s) VALUES (?, ?)";
  private static final String FIND = "SELECT id FROM \"%s\" WHERE %s = ? AND %s = ?";
  private static final String BATCH_INSERT = INSERT;
  // Starts from an always-false predicate so every pair can be appended as " OR (...)".
  private static final String BATCH_FIND = "SELECT %s, %s, %s FROM \"%s\" WHERE 1 = 0";

  private final String firstColumn;
  private final String secondColumn;
  /** Additional VALUES tuple appended once per extra row of a multi-row insert. */
  private final String batchInsertPattern;
  /** Disjunct appended once per pair looked up by {@link #batchFind(Set)}. */
  private final String batchFindPattern;

  TwoIntTable(String table, String firstColumn, String secondColumn) {
    insertString = String.format(INSERT, table, firstColumn, secondColumn);
    findString = String.format(FIND, table, firstColumn, secondColumn);
    batchInsertString = String.format(BATCH_INSERT, table, firstColumn, secondColumn);
    batchFindString = String.format(BATCH_FIND, ID, firstColumn, secondColumn, table);
    this.firstColumn = firstColumn;
    this.secondColumn = secondColumn;
    this.batchInsertPattern = ", (?, ?)";
    this.batchFindPattern = String.format(" OR (%s = ? AND %s = ?)", firstColumn, secondColumn);
  }

  /**
   * Binds {@code value} to parameter {@code parameterIndex}, using SQL NULL when it is null.
   *
   * @param statement the statement to bind into
   * @param parameterIndex the 1-based parameter index
   * @param value the value to bind; may be null
   * @throws SQLException if binding fails
   */
  private static void bindNullableInt(PreparedStatement statement, int parameterIndex,
      Integer value) throws SQLException {
    if (value != null) {
      statement.setInt(parameterIndex, value);
    } else {
      statement.setNull(parameterIndex, Types.INTEGER);
    }
  }

  /**
   * Looks up the ids of all rows matching any of the given (first, second) pairs with a single
   * query.
   *
   * @param values the pairs to look up; may be null or empty
   * @return a map from each matched pair to its row id; empty when {@code values} is null or
   *     empty. Pairs with no matching row are absent from the map.
   * @throws SQLException if the query fails
   */
  public Map<Pair<Integer, Integer>, Integer> batchFind(Set<Pair<Integer, Integer>> values)
      throws SQLException {
    Map<Pair<Integer, Integer>, Integer> found = new HashMap<Pair<Integer, Integer>, Integer>();
    if (values == null || values.isEmpty()) {
      return found;
    }
    StringBuilder queryBuilder = new StringBuilder(batchFindString);
    for (int i = 0; i < values.size(); ++i) {
      queryBuilder.append(batchFindPattern);
    }
    PreparedStatement batchFindStatement =
        getConnection().prepareStatement(queryBuilder.toString());
    try {
      int parameterIndex = 1;
      for (Pair<Integer, Integer> value : values) {
        bindNullableInt(batchFindStatement, parameterIndex++, value.getO1());
        bindNullableInt(batchFindStatement, parameterIndex++, value.getO2());
      }

      ResultSet resultSet = batchFindStatement.executeQuery();
      try {
        while (resultSet.next()) {
          found.put(
              new Pair<Integer, Integer>(resultSet.getInt(firstColumn),
                  resultSet.getInt(secondColumn)),
              resultSet.getInt(ID));
        }
      } finally {
        resultSet.close();
      }
    } finally {
      // A fresh statement is prepared per call, so it must be released here.
      batchFindStatement.close();
    }
    return found;
  }

  /**
   * Returns the ids for the given pairs. Pairs that already exist are looked up; the rest are
   * inserted.
   *
   * @param values the pairs to resolve; may be null or empty
   * @return the ids of the existing rows together with the generated ids of the new rows; empty
   *     when {@code values} is null or empty
   * @throws SQLException if the lookup or the insert fails
   */
  public Set<Integer> batchInsert(Set<Pair<Integer, Integer>> values) throws SQLException {
    if (values == null || values.isEmpty()) {
      return new HashSet<Integer>();
    }
    Map<Pair<Integer, Integer>, Integer> found = batchFind(values);
    // Set difference: keep only the pairs that were not found.
    Set<Pair<Integer, Integer>> toBeInserted = new HashSet<Pair<Integer, Integer>>(values);
    toBeInserted.removeAll(found.keySet());
    Set<Integer> result = batchForceInsert(toBeInserted);
    result.addAll(found.values());

    return result;
  }

  /**
   * Inserts every given pair in a single multi-row INSERT, without checking for existing rows.
   *
   * @param values the pairs to insert; may be null or empty
   * @return the generated keys reported by the driver (first column of
   *     {@code getGeneratedKeys()}); empty when {@code values} is null or empty
   * @throws SQLException if the insert fails
   */
  public Set<Integer> batchForceInsert(Set<Pair<Integer, Integer>> values) throws SQLException {
    Set<Integer> result = new HashSet<Integer>();
    if (values == null || values.isEmpty()) {
      return result;
    }

    // The base statement already contains the first "(?, ?)" tuple, hence one fewer pattern.
    StringBuilder queryBuilder = new StringBuilder(batchInsertString);
    for (int i = 1; i < values.size(); ++i) {
      queryBuilder.append(batchInsertPattern);
    }
    PreparedStatement batchInsertStatement =
        getConnection().prepareStatement(queryBuilder.toString(), AUTOGENERATED_ID);
    try {
      int parameterIndex = 1;
      for (Pair<Integer, Integer> value : values) {
        bindNullableInt(batchInsertStatement, parameterIndex++, value.getO1());
        bindNullableInt(batchInsertStatement, parameterIndex++, value.getO2());
      }

      batchInsertStatement.executeUpdate();
      ResultSet resultSet = batchInsertStatement.getGeneratedKeys();
      try {
        while (resultSet.next()) {
          result.add(resultSet.getInt(1));
        }
      } finally {
        resultSet.close();
      }
    } finally {
      // A fresh statement is prepared per call, so it must be released here.
      batchInsertStatement.close();
    }

    return result;
  }

  /**
   * Inserts one row per element of {@code values}, each paired with {@code firstValue}, without
   * checking for existing rows.
   *
   * <p>The pairs are collected into a set before inserting, so duplicate elements of
   * {@code values} produce only one row.
   *
   * @param firstValue the value for the first column of every row
   * @param values the values for the second column; may be null or empty
   * @return the generated keys of the inserted rows; empty when {@code values} is null or empty
   * @throws SQLException if the insert fails
   */
  public Set<Integer> batchForceInsert(Integer firstValue, List<Integer> values)
      throws SQLException {
    if (values == null || values.isEmpty()) {
      return new HashSet<Integer>();
    }
    Set<Pair<Integer, Integer>> newValues = new HashSet<Pair<Integer, Integer>>();
    for (Integer value : values) {
      newValues.add(new Pair<Integer, Integer>(firstValue, value));
    }

    return batchForceInsert(newValues);
  }

  /**
   * Returns the id of the row for ({@code firstValue}, {@code secondValue}). The row is inserted
   * first if {@link #find} does not report it.
   *
   * @param firstValue the first column value; may be null
   * @param secondValue the second column value; may be null
   * @return the existing id, or the result of {@link #forceInsert}
   * @throws SQLException if the lookup or the insert fails
   */
  public int insert(Integer firstValue, Integer secondValue) throws SQLException {
    int id = find(firstValue, secondValue);
    if (id != NOT_FOUND) {
      return id;
    }
    return forceInsert(firstValue, secondValue);
  }

  /**
   * Inserts a row for ({@code firstValue}, {@code secondValue}) without checking for an existing
   * one. The insert uses the cached {@code insertStatement}.
   *
   * @param firstValue the first column value; null is stored as SQL NULL
   * @param secondValue the second column value; null is stored as SQL NULL
   * @return the value produced by {@code processIntFindQuery} on the INSERT ... RETURNING
   *     statement (presumably the new row id; confirm against {@code Table})
   * @throws SQLException if the insert fails
   */
  public int forceInsert(Integer firstValue, Integer secondValue) throws SQLException {
    if (insertStatement == null || insertStatement.isClosed()) {
      // NOTE(review): the returned column name is hard-coded as "id" here, while batchFind uses
      // the ID constant; confirm they agree.
      insertStatement = getConnection().prepareStatement(insertString + " returning id;");
    }
    bindNullableInt(insertStatement, 1, firstValue);
    bindNullableInt(insertStatement, 2, secondValue);
    return processIntFindQuery(insertStatement);
  }

  /**
   * Looks up the id of the row for ({@code firstValue}, {@code secondValue}). The lookup uses
   * the cached {@code findStatement}.
   *
   * @param firstValue the first column value; bound as SQL NULL when null
   * @param secondValue the second column value; bound as SQL NULL when null
   * @return the value produced by {@code processIntFindQuery}; {@link #insert} treats
   *     {@code NOT_FOUND} as "no such row"
   * @throws SQLException if the query fails
   */
  public int find(Integer firstValue, Integer secondValue) throws SQLException {
    if (findStatement == null || findStatement.isClosed()) {
      findStatement = getConnection().prepareStatement(findString);
    }
    bindNullableInt(findStatement, 1, firstValue);
    bindNullableInt(findStatement, 2, secondValue);
    return processIntFindQuery(findStatement);
  }
}
